# 1. Encode scene
# 1a. inputs are an image and two boxes, the boxes are used to constrain the attention (or maybe pool features)
# 1b. positional embeddings are also needed
# 2. Encode memory
# 2a. encode each image and then pool features for entities and actions
# 2b. how will position information be used here?
# 3. Queries are formed by projecting the features from 2
# 4. The queries attend to the encoded scene, the entity queries attend to the respective pooled features
# 5. Pick queries are contrasted with every pixel in the features map, then we have a classification loss
# 5a. the feature map does not have the correct dimension...
# 6. The place query is concatenated to a pooled feature from the pick crop, then same as the pick handling
# 6b. a simpler variant to try is to handle place the same way as pick
#
# memory loader
# inherit train-tester and handle memory

from functools import partial
import math

import torch
from torch import nn
import torchvision


class MAManipulator(nn.Module):
    """Memory-augmented Analogical Manipulator.

    Work in progress: the decoding step (`_decode`) is not implemented yet,
    so `forward` cannot complete until it is filled in.

    NOTE(review): `self.encoder` is the complete MoCo base encoder. If that
    encoder ends in pooling plus a projection head (as a stock torchvision
    ResNet does), `_encode_scene` yields a vector per image rather than the
    (B, C, H, W) map that `roi_align` and the positional embedding expect.
    Confirm and, if so, truncate the encoder before pooling.
    """

    def __init__(self, num_decoder_layers=3, eps=1.0,
                 ckpt_path='r-50-1000ep.pth.tar'):
        """
        Build the encoder, positional embedding, decoder and heads.

        Args:
            num_decoder_layers: number of decoder layers; one prediction
                head is created per layer.
            eps: half-size of the box built around each memory pick/place
                location before pooling. The default is a placeholder
                (the value was never defined before) - TODO tune.
            ckpt_path: path of the MoCo checkpoint whose 'state_dict'
                entry initializes the scene encoder.
        """
        super().__init__()
        self.num_decoder_layers = num_decoder_layers
        self.eps = eps

        # Scene encoder
        # NOTE: torch.load unpickles the file, only use trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        moco = MoCo_ResNet(
            partial(
                torchvision.models.__dict__['resnet50'],
                zero_init_residual=True
            ),
            256, 4096, 1.0
        )
        moco.load_state_dict(ckpt['state_dict'])
        self.encoder = moco.base_encoder

        # Positional encodings (128 features per axis -> 256 channels)
        self.pos_embedding = PositionEmbeddingSine(128, normalize=True)

        # Query projector
        self.query_proj_layer = nn.Linear(256, 256)

        # Decoder layers
        self.decoder = nn.ModuleList()
        for _ in range(num_decoder_layers):
            self.decoder.append(BiDecoderLayer(
                256, n_heads=8, dim_feedforward=256
            ))

        # Decoder heads, one per decoder layer
        self.prediction_heads = nn.ModuleList()
        for _ in range(num_decoder_layers):
            self.prediction_heads.append(PredictionModule(256))

    def forward(self, scene, scene_pick_objs, scene_place_objs,
                mem, mem_pick_objs, mem_place_objs, mem_pick, mem_place):
        """
        Encode scene and memory, then decode pick/place per layer.

        Args:
            scene: current image batch, fed to the encoder.
            scene_pick_objs: boxes of pick candidates in the scene, in
                whatever box format `torchvision.ops.roi_align` accepts.
            scene_place_objs: same, for place candidates.
            mem: memory image batch, fed to the same encoder.
            mem_pick_objs: boxes of pick objects in the memory image.
            mem_place_objs: boxes of place objects in the memory image.
            mem_pick: pick location(s) in the memory image.
            mem_place: place location(s) in the memory image.

        Returns:
            (pick, place): per-layer outputs of `_decode`, stacked on dim 1.
        """
        # Encode scene
        scene = self._encode_scene(scene)

        # Pool features (boxes are pooled from the encoded scene)
        scene_pick_objs = self._pool_feats(scene, scene_pick_objs)
        scene_place_objs = self._pool_feats(scene, scene_place_objs)

        # Add positional embeddings to scene
        scene = self._add_pos_embed(scene)

        # Initialize queries
        mem_pick_objs, mem_place_objs, mem_pick, mem_place = self._mem_encode(
            mem, mem_pick_objs, mem_place_objs, mem_pick, mem_place
        )

        # Decode and predict
        pick, place = [], []
        for k in range(self.num_decoder_layers):
            pick_, place_ = self._decode(
                k, scene, scene_pick_objs, scene_place_objs,
                mem, mem_pick_objs, mem_place_objs, mem_pick, mem_place
            )
            pick.append(pick_)
            place.append(place_)
        return torch.stack(pick, 1), torch.stack(place, 1)

    def _encode_scene(self, scene):
        """Run the image encoder on a batch."""
        return self.encoder(scene)

    def _add_pos_embed(self, scene):
        """Add sine positional embeddings to the encoded scene."""
        return self.pos_embedding(scene) + scene

    def _pool_feats(self, scene, objs):
        """RoI-align `objs` on `scene` with a 1x1 output per box."""
        return torchvision.ops.roi_align(scene, objs, 1)

    def _mem_encode(self, mem, mem_pick_objs, mem_place_objs,
                    mem_pick, mem_place):
        """
        Encode the memory image and pool features for its annotations.

        Pick/place locations are expanded by `self.eps` on each side and
        concatenated along dim 0 before pooling.
        NOTE(review): concatenating on dim 0 only yields (x1, y1, x2, y2)
        for a single 1-D point; confirm the layout for batched input.
        """
        mem = self._encode_scene(mem)
        mem_pick_objs = self._pool_feats(mem, mem_pick_objs)
        mem_place_objs = self._pool_feats(mem, mem_place_objs)
        eps = self.eps
        mem_pick = self._pool_feats(
            mem, torch.cat((mem_pick - eps, mem_pick + eps))
        )
        mem_place = self._pool_feats(
            mem, torch.cat((mem_place - eps, mem_place + eps))
        )
        return mem_pick_objs, mem_place_objs, mem_pick, mem_place

    def _decode(self, layer_ind, scene, scene_pick_objs, scene_place_objs,
                mem, mem_pick_objs, mem_place_objs, mem_pick, mem_place):
        """
        Decode one layer and return `(pick, place)`; not implemented yet.

        Raises:
            NotImplementedError: always, until the decoding step is written
                (previously the stub returned None, which `forward` then
                failed to unpack).
        """
        raise NotImplementedError(
            'MAManipulator._decode is not implemented yet'
        )


class PositionEmbeddingSine(nn.Module):
    """
    Sinusoidal position embedding for 2D feature maps, in the spirit of
    the "Attention is all you need" encoding extended to images.

    The output carries 2 * num_pos_feats channels: the row (y) encoding
    followed by the column (x) encoding.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi

    @staticmethod
    def _interleave_sin_cos(coords, freqs):
        """Map coords to interleaved sin (even) / cos (odd) channels."""
        angles = coords[:, :, :, None] / freqs
        paired = torch.stack(
            (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()),
            dim=4
        )
        return paired.flatten(3)

    def forward(self, x):
        """Image x (B, F, H, W) -> embedding (B, 2 * num_pos_feats, H, W)."""
        # 1-based row / column index of every location
        ones = torch.ones_like(x[:, 0])
        rows = ones.cumsum(1, dtype=torch.float32)
        cols = ones.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # Rescale cell centers to (0, 2 * pi)
            eps = 1e-6
            rows = (rows - 0.5) / (rows[:, -1:, :] + eps) * self.scale
            cols = (cols - 0.5) / (cols[:, :, -1:] + eps) * self.scale

        # Each consecutive (sin, cos) channel pair shares one frequency
        channel = torch.arange(self.num_pos_feats, device=x.device).float()
        freqs = self.temperature ** (
            2 * (channel // 2) / self.num_pos_feats
        )

        enc_x = self._interleave_sin_cos(cols, freqs)
        enc_y = self._interleave_sin_cos(rows, freqs)
        return torch.cat((enc_y, enc_x), dim=3).permute(0, 3, 1, 2)


class MoCo(nn.Module):
    """
    Build a MoCo model with a base encoder, a momentum encoder, and two MLPs
    https://arxiv.org/abs/1911.05722

    `_build_projector_and_predictor_mlps` is a no-op here; a subclass must
    override it to create `self.predictor`, which `forward` relies on.
    """
    def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
        """
        base_encoder: callable building an encoder, called as
            base_encoder(num_classes=mlp_dim)
        dim: feature dimension (default: 256)
        mlp_dim: hidden dimension in MLPs (default: 4096)
        T: softmax temperature (default: 1.0)
        """
        super().__init__()

        self.T = T

        # build encoders
        self.base_encoder = base_encoder(num_classes=mlp_dim)
        self.momentum_encoder = base_encoder(num_classes=mlp_dim)

        self._build_projector_and_predictor_mlps(dim, mlp_dim)

        for param_b, param_m in zip(
                self.base_encoder.parameters(),
                self.momentum_encoder.parameters()
        ):
            param_m.data.copy_(param_b.data)  # initialize
            param_m.requires_grad = False  # not update by gradient

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim,
                   last_bn=True):
        """
        Build a bias-free MLP of `num_layers` linear layers.

        Hidden layers are followed by BatchNorm + ReLU; the last layer is
        followed by a non-affine BatchNorm only when `last_bn` is True.
        """
        mlp = []
        for i in range(num_layers):
            dim1 = input_dim if i == 0 else mlp_dim
            dim2 = output_dim if i == num_layers - 1 else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if i < num_layers - 1:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            elif last_bn:
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Hook for subclasses to attach projector/predictor heads."""
        pass

    @torch.no_grad()
    def _update_momentum_encoder(self, m):
        """Momentum update of the momentum encoder"""
        for param_b, param_m in zip(
                self.base_encoder.parameters(),
                self.momentum_encoder.parameters()
        ):
            param_m.data = param_m.data * m + param_b.data * (1. - m)

    def contrastive_loss(self, q, k):
        """
        Cross-entropy between queries `q` and keys gathered from all ranks.

        The positive for the i-th local query is the i-th key of this
        rank's slice in the gathered keys, hence the `N * rank` offset.
        The result is scaled by 2 * T.
        """
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        # gather all targets
        k = concat_all_gather(k)
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        N = logits.shape[0]  # batch size per GPU
        # Build labels on the logits' device instead of forcing CUDA,
        # so the loss also works with CPU tensors.
        labels = torch.arange(
            N, dtype=torch.long, device=logits.device
        ) + N * torch.distributed.get_rank()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)

    def forward(self, x1, x2, m):
        """
        Input:
            x1: first views of images
            x2: second views of images
            m: moco momentum
        Output:
            loss
        """

        # compute features
        q1 = self.predictor(self.base_encoder(x1))
        q2 = self.predictor(self.base_encoder(x2))

        with torch.no_grad():  # no gradient
            self._update_momentum_encoder(m)  # update the momentum encoder

            # compute momentum features as targets
            k1 = self.momentum_encoder(x1)
            k2 = self.momentum_encoder(x2)

        # symmetrized loss: each view's queries against the other's keys
        return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1)


class MoCo_ResNet(MoCo):
    """MoCo variant whose encoders expose their classifier as `.fc`."""

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Swap each encoder's `fc` for a projector and add a predictor."""
        # Input width of the original classifier layer
        feat_dim = self.base_encoder.fc.weight.shape[1]

        # Drop both stock classifiers before attaching the new heads
        encoders = (self.base_encoder, self.momentum_encoder)
        for encoder in encoders:
            del encoder.fc

        # projectors: 2-layer MLPs mapping feat_dim -> mlp_dim -> dim
        for encoder in encoders:
            encoder.fc = self._build_mlp(2, feat_dim, mlp_dim, dim)

        # predictor: 2-layer MLP on the projected features, no final BN
        self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False)


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather `tensor` from every process and concatenate along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()

    # One receive buffer per process, shaped like the local tensor
    buffers = []
    for _ in range(world_size):
        buffers.append(torch.ones_like(tensor))

    torch.distributed.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)


class BiDecoderLayer(nn.Module):
    """Self->cross_o->cross_s layer for proposals.

    Object and action queries are concatenated for self attention and for
    cross attention to the object entities; only the action queries then
    cross attend to the scene. A shared FFN is applied to all queries.
    """

    def __init__(self, d_model, n_heads, dim_feedforward=2048, dropout=0.1):
        """Initialize layers, d_model is the encoder dimension."""
        super().__init__()

        # Self attention
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads,
            dropout=dropout
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Cross attention to entities
        self.cross_o = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout
        )
        self.dropout_o = nn.Dropout(dropout)
        self.norm_o = nn.LayerNorm(d_model)

        # Cross attention to scene
        self.cross_s = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout
        )
        self.dropout_s = nn.Dropout(dropout)
        self.norm_s = nn.LayerNorm(d_model)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, obj_query, act_query, vis_feats, obj_feats,
                padding_mask, obj_key_padding_mask):
        """
        Forward pass.

        Args:
            obj_query: (B, N_obj, F)
            act_query: (B, N_act, F)
            vis_feats: (B, V, F)
            obj_feats: (B, L, F)
            padding_mask: (B, N_obj + N_act) or None (for the joint query)
            obj_key_padding_mask: (B, L) or None

        Returns:
            obj_query: (B, N_obj, F)
            act_query: (B, N_act, F)
        """
        # Number of object queries, captured before any tensor is rebound:
        # the attention modules are sequence-first, so size(1) of an
        # intermediate tensor is the batch size, not the query count.
        n_obj = obj_query.size(1)

        query = torch.cat((obj_query, act_query), 1)
        query = query.transpose(0, 1)  # (N_obj + N_act, B, F)

        # Self attention
        query2 = self.self_attn(
            query, query, query,
            attn_mask=None,
            key_padding_mask=padding_mask
        )[0]
        query = self.norm1(query + self.dropout1(query2))

        # Cross attend to object entities
        obj_feats = obj_feats.transpose(0, 1)
        query2 = self.cross_o(
            query=query,
            key=obj_feats,
            value=obj_feats,
            attn_mask=None,
            key_padding_mask=obj_key_padding_mask  # (B, L)
        )[0]
        query = self.norm_o(query + self.dropout_o(query2))

        # Cross attend to scene (action queries only)
        obj_part = query[:n_obj]
        act_part = query[n_obj:]
        vis_feats = vis_feats.transpose(0, 1)
        act_part2 = self.cross_s(
            query=act_part,
            key=vis_feats,
            value=vis_feats,
            attn_mask=None,
            key_padding_mask=None
        )[0]
        act_part = self.norm_s(act_part + self.dropout_s(act_part2))

        # FFN
        query = torch.cat((obj_part, act_part))
        query = self.norm2(query + self.ffn(query))
        query = query.transpose(0, 1).contiguous()  # back to (B, N, F)

        return query[:, :n_obj], query[:, n_obj:]


class ThreeLayerMLP(nn.Module):
    """Point-wise 3-layer MLP built from 1x1 Conv1d layers.

    The two hidden layers are bias-free convolutions followed by
    BatchNorm, ReLU and Dropout(0.3); the output layer is a plain
    convolution with bias.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        layers = []
        for _ in range(2):  # two identical hidden stages
            layers.extend((
                nn.Conv1d(dim, dim, 1, bias=False),
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(0.3),
            ))
        layers.append(nn.Conv1d(dim, out_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; x is channel-first, e.g. (B, dim, N)."""
        out = self.net(x)
        return out


class PredictionModule(nn.Module):
    """Pick / place / rotation heads on top of the decoded queries.

    Pick and place heads output 2 channels each, the rotation head 36.
    """

    def __init__(self, d_model):
        super().__init__()
        # (attribute name, output channels), created in this order
        head_specs = (
            ('pick_module', 2),
            ('place_module', 2),
            ('rotation_module', 36),
        )
        for attr, n_out in head_specs:
            setattr(self, attr, ThreeLayerMLP(d_model, n_out))

    def forward(self, pick_query, place_query):
        """
        Run the three heads; place and rotation share `place_query`.

        Queries are handed to the heads unchanged.
        NOTE(review): ThreeLayerMLP is Conv1d-based and takes channel-first
        input (B, d_model, N), while this method was previously documented
        as taking (B, N, d_model) - confirm which layout callers provide.

        Returns:
            (pick, place, rotation) head outputs.
        """
        pick_out = self.pick_module(pick_query)
        place_out = self.place_module(place_query)
        rotation_out = self.rotation_module(place_query)
        return pick_out, place_out, rotation_out
